import itertools
import os
import re
import pandas as pd
from pathlib import Path
from datetime import datetime
from multiprocessing import Pool
from typing import Sequence, Tuple, Dict, Union, Set, Iterable

import matplotlib.pyplot as plt
import numpy as np
import progressbar
import seaborn as sns
import ujson

from .config import EXPORT_DIRECTORY
from .data_management.geo import extract_country, extract_isp, extract_isp_stats, extract_day_night, \
    extract_day_night_movement, extract_continent, continent_list, nb_of_continent

# Node-state buckets appearing in the crawl JSON exports — presumably the
# top-level keys of each dump (not referenced in this module's visible code;
# confirm against the crawler/export writer).
buckets = ['up', 'down', 'pending', 'waiting']
# Default percentile cut-off passed to pie(); 0 disables the cut entirely
# (pie() skips cut_to_quartile when the percentile is 0).
PERCENTILE_ = 0


def extract_pending(data):
    """Collect every 'Pending' timestamp found in any bucket of *data* and
    convert it to seconds elapsed since the earliest one.

    Returns a (range, list) pair: node indices and the sorted elapsed times.
    Raises IndexError if no entry carries a 'Pending' timestamp.
    """
    fmt = '%Y-%m-%d %H:%M:%S.%f'
    stamps = sorted(
        entry['Pending']
        for entry in itertools.chain.from_iterable(data.values())
        if 'Pending' in entry
    )
    origin = datetime.strptime(stamps[0], fmt)
    elapsed = [(datetime.strptime(stamp, fmt) - origin).total_seconds()
               for stamp in stamps]
    return range(len(elapsed)), elapsed


def extract_pending_per_s(data):
    """Bucket the pending-node arrival times into one-second bins.

    Returns a DataFrame with one row per elapsed second, column
    "number of nodes" holding how many nodes were discovered that second.
    """
    _, elapsed = extract_pending(data)
    last_second = int(elapsed[-1])
    counts = [
        sum(1 for t in elapsed if second <= t < second + 1)
        for second in range(last_second + 2)
    ]
    return pd.DataFrame(counts, columns=["number of nodes"])


def extract_up(data):
    """Convert the 'Up' timestamps of the nodes in data['up'] to seconds
    elapsed since the earliest one.

    Returns a (range, list) pair: node indices and the sorted elapsed times.
    Raises IndexError if no node carries an 'Up' timestamp.
    """
    fmt = '%Y-%m-%d %H:%M:%S.%f'
    stamps = sorted(node['Up'] for node in data['up'] if 'Up' in node)
    origin = datetime.strptime(stamps[0], fmt)
    elapsed = [(datetime.strptime(stamp, fmt) - origin).total_seconds()
               for stamp in stamps]
    return range(len(elapsed)), elapsed


def print_country_data(x, y, title=None):
    """Draw and show a horizontal bar plot: one bar per label in *x*
    (country names, presumably), with bar length taken from *y*."""
    # widen the left margin so the horizontal category labels fit
    plt.subplots_adjust(left=0.15, bottom=0, right=1, top=1, wspace=0, hspace=0)
    sns.set_context('paper')
    sns.barplot(x=y, y=x, orient='h')
    if title:
        plt.title(title)
    plt.show()


def pie(x, y, percentile, title=None):
    """Draw and show a pie chart of values *x* labeled by *y*.

    When *percentile* is non-zero, the slices whose value falls at or below
    that percentile are cut away via cut_to_quartile and the title is
    annotated with the share actually displayed.
    """
    values, labels = cut_to_quartile(x, y, percentile) if percentile != 0 else (x, y)
    plt.pie(values, labels=labels, normalize=True)
    if title:
        if percentile != 0:
            title += f"({str(100 - percentile)}% most represented)"
        plt.title(title)
    plt.show()


def print_time_data(x, y, title=None, xlabel=None, ylabel=None):
    """Plot *y* against *x* as a simple line chart and show it.

    Args:
        x, y: coordinate sequences passed straight to ``plt.plot``.
        title: optional figure title.
        xlabel: x-axis label; defaults to 'Time (s)'.
        ylabel: y-axis label; defaults to 'Number of nodes'.
    """
    sns.set()
    plt.plot(x, y)
    plt.xlabel(xlabel if xlabel is not None else 'Time (s)')
    # BUG FIX: this previously called plt.xlabel a second time, overwriting
    # the x label and leaving the y axis unlabeled.
    plt.ylabel(ylabel if ylabel is not None else 'Number of nodes')
    if title:
        plt.title(title)
    plt.show()


def cut_to_quartile(x, y, percentile_):
    """Keep only the (x, y) pairs whose x value is strictly above the rounded
    *percentile_*-th percentile of *x*.

    Returns a (xs, ys) tuple of the surviving values.
    """
    threshold = round(np.percentile(x, percentile_))
    survivors = [(value, label) for value, label in zip(x, y) if value > threshold]
    return tuple(zip(*survivors))


def plot_by_time_of_day():
    """Plot, per continent, the number of up nodes against the time of day,
    aggregated over every JSON export in EXPORT_DIRECTORY."""
    x_coords = []
    y_coords = [[] for _ in range(nb_of_continent)]  # one series per continent

    for filename in progressbar.progressbar([n for n in sorted(os.listdir(EXPORT_DIRECTORY)) if
                                             n.endswith(".json")]):  # for each file in EXPORT_DIRECTORY
        try:
            with open(EXPORT_DIRECTORY + "/" + filename, 'r') as ofi:
                data = ujson.load(ofi)  # find the datas in the json file
            _, y_data = extract_continent(data)

            # find the exact point for time axis: the filename embeds a
            # timestamp ('YYYY-MM-DD HH:MM:SS.mmm.json'); slice out HH, MM,
            # SS and milliseconds at fixed offsets.
            hour, minute, second, mu_second = [int(filename[i:i + w]) for i, w in ((11, 2), (14, 2), (17, 2), (20, 3))]
            # fractional hour of day; (+1) % 24 shifts by one hour — looks
            # like a timezone offset, TODO confirm intent
            point = (hour + 1) % 24 + (minute + (second + mu_second / 1000) / 60) / 60

            if y_data != [0] * nb_of_continent:  # no geoip db present when crawl was achieved
                # BUG FIX: membership was previously tested against the labels
                # returned by extract_continent instead of x_coords, so
                # duplicate time points were never merged.
                if point in x_coords:  # point already exists, pretty rare but it may happen
                    index_ = x_coords.index(point)
                    for k in range(nb_of_continent):
                        # pretty bad way to reckon mean but it's so rare that it's ok
                        y_coords[k][index_] = (y_coords[k][index_] + y_data[k]) / 2
                else:
                    # BUG FIX: the insertion index was previously computed
                    # BEFORE sorting x_coords, so y values were inserted at the
                    # pre-sort position and drifted out of sync with the x axis.
                    x_coords.append(point)
                    x_coords.sort()
                    index_ = x_coords.index(point)
                    for k in range(nb_of_continent):
                        y_coords[k].insert(index_, y_data[k])
        except IsADirectoryError:
            pass
    for k in range(nb_of_continent):
        plt.plot(x_coords, y_coords[k], label=continent_list[k])
    plt.xlabel('Time (h)')
    plt.ylabel('Number of nodes')
    plt.legend()
    plt.show()


def grouper(iterable: Iterable, n: int, fillvalue=None) -> itertools.zip_longest:
    """Batch *iterable* into tuples of length *n*, padding the last tuple
    with *fillvalue* (the classic itertools "grouper" recipe)."""
    # the SAME iterator repeated n times: zip_longest advances it n steps
    # per output tuple
    iterators = (iter(iterable),) * n
    return itertools.zip_longest(*iterators, fillvalue=fillvalue)


def aggregate(filenames: Sequence[str]) -> Dict[str, Union[datetime, Set[Tuple[str, str]]]]:
    """Merge the 'up' node lists of one group of crawl export files.

    The group timestamp is parsed from the first filename's leading
    'YYYY-MM-DD HH:MM:SS' prefix; "data" collects the unique
    (IP address, UDP port) pairs seen across all files of the group.
    Trailing None entries (grouper() padding) are skipped.
    """
    merged = set()
    for name in filenames:
        if name is None:
            continue
        with open(EXPORT_DIRECTORY + "/" + name, "r") as file:
            for node in ujson.load(file)['up']:
                merged.add((node['IP address'], node['UDP port']))
    return {
        "time": datetime.strptime(filenames[0][:19], '%Y-%m-%d %H:%M:%S'),
        "data": merged,
    }


def compute_churn(pair: Tuple[
    Dict[str, Union[datetime, Set[Tuple[str, str]]]], Dict[str, Union[datetime, Set[Tuple[str, str]]]]]) -> Tuple[
    datetime, float, float]:
    """Compute the churn between two consecutive aggregated crawl groups.

    Returns (timestamp of the later group,
             percentage of its nodes absent from the earlier group,
             percentage of the earlier group's nodes that disappeared).
    """
    before, after = pair
    nodes_before = before["data"]
    nodes_after = after["data"]
    entering = (len(nodes_after - nodes_before) / len(nodes_after)) * 100
    leaving = (len(nodes_before - nodes_after) / len(nodes_before)) * 100
    return after["time"], entering, leaving


def plot_churn(group_size: int):
    """Plot the network churn rate between consecutive groups of crawls, plus
    the number of up nodes per group; show both figures and save them as SVG
    in EXPORT_DIRECTORY.

    group_size: number of consecutive export files merged into one data point.
    """
    # matches export names like 'YYYY-MM-DD HH:MM:SS.mmm.json'
    filename_regex = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}\.json\Z"
    names = [n for n in sorted(os.listdir(EXPORT_DIRECTORY)) if re.match(filename_regex, n)]
    with Pool() as pool:
        # aggregate each file group in parallel, then compare each group to
        # the next one to get (timestamp, entering %, leaving %) triples
        data = pool.map(aggregate, grouper(names, group_size))
        x, entering, leaving = list(zip(*pool.map(compute_churn, zip(data, data[1:]))))

    sns.set()
    plt.rcParams["figure.figsize"] = (18.5, 10.5)
    plt.plot(x, leaving, label='Nodes leaving the network')
    plt.plot(x, entering, label='Nodes (re)joining the network')

    # finding y limit for display: headroom of 15 points, capped at 100%
    max_y = 100 if max(leaving) + 15 > 100 or max(entering) + 15 > 100 else max(max(leaving) + 15, max(entering) + 15)
    plt.axis([x[0], x[-1], 0, max_y])

    plt.xticks(x, labels=[dt.strftime("%m/%d, %H:%M:%S") for dt in x], rotation=45)
    plt.xlabel('Time (h)')
    plt.ylabel('Rate (%)')
    plt.title(f'Percentage of changes ({len(names)} crawls grouped by {group_size})')
    plt.subplots_adjust(left=0.125, bottom=0.32, right=0.9, top=0.88, wspace=0.2, hspace=0.2)
    plt.legend()
    plt.tight_layout()
    plt.savefig(
        f"{EXPORT_DIRECTORY}/churn_{x[0].strftime('%m%d%H%M%S')}_{x[-1].strftime('%m%d%H%M%S')}_{len(names)}_{group_size}.svg")
    plt.show()

    # second figure: total number of distinct up nodes per aggregated group
    t, up = zip(*[(d["time"], len(d["data"])) for d in data])

    sns.set()
    plt.xticks(t, [ts.strftime("%m/%d, %H:%M:%S") for ts in t], rotation=45)
    plt.plot(t, up)
    plt.xlabel('Time (h)')
    plt.ylabel('Number of nodes')
    plt.title(f'Number of up nodes ({len(names)} crawls grouped by {group_size})')
    plt.savefig(
        f"{EXPORT_DIRECTORY}/ups_{x[0].strftime('%m%d%H%M%S')}_{x[-1].strftime('%m%d%H%M%S')}_{len(names)}_{group_size}.svg")
    plt.show()


def generate_graphs(path):  # in charge of generating graphs after crawl
    """Load the most recent crawl dump found under *path* and render the
    standard post-crawl graphs: discovered vs up nodes over time and the
    per-second discovery rate.

    The geo/ISP based graphs stay commented out because they require a
    location/ISP database (see WARNING below).
    """
    files = [Path(path) / n for n in sorted(os.listdir(path)) if n.endswith(".json")]
    filename = files[-1]  # names sort chronologically, so last == most recent
    # BUG FIX: the f-string previously contained no placeholder, so the
    # filename was never shown.
    print(f"from {filename}")
    with open(filename, "r") as f:
        data = ujson.load(f)

    """
    x, y = extract_info('pending', data)
    if x != [] and y != []:
        print_time_data(y, x, 'Total of discovered nodes in function of the elapsed time')

    x, y = extract_info('up_time', data)
    if x != [] and y != []:
        print_time_data(y, x, 'Total of up nodes in function of the elapsed time')
    """

    x1, y1 = extract_pending(data)
    x2, y2 = extract_up(data)

    if x1 and y1 and x2 and y2:
        # elapsed time on the x axis, cumulative node count on the y axis
        sns.set()
        plt.plot(y1, x1)
        plt.plot(y2, x2)

        plt.xlabel('Time (s)')
        plt.ylabel('Number of nodes')
        plt.legend(["discovered nodes", "up nodes"])

        plt.show()

    df_pending_sec = extract_pending_per_s(data)
    if not df_pending_sec.empty:
        sns.set()
        sns.lineplot(x=df_pending_sec.index, y="number of nodes", data=df_pending_sec)
        plt.xlabel('Elapsed time (s)')
        plt.ylabel('Number of nodes')
        plt.show()

    # WARNING: requires location/ISP db:

    # x, y = extract_country(data)
    # if x != [] and y != []:
    #     print_country_data(x, y, 'Geographical repartition of nodes')
    #     pie(y, x, PERCENTILE_, 'Geographical repartition of nodes')

    # x, y = extract_isp(data)
    # if x != [] and y != []:
    #     pie(y, x, PERCENTILE_, 'ISP')

    # x, y = extract_isp_stats(data)
    # if (x != [] and y != []) and (y != [0, 0, 0]):
    #     pie(y, x, 0, 'ISP proportion')

    # x, y = extract_day_night(data)
    # if (x != [] and y != []) and (y != [0, 0]):
    #     pie(y, x, 0, 'Day/Night distribution')

    # x, y = extract_day_night_movement(data)
    # if (x != [] and y != []) and (y != [0]):
    #     print_time_data(x, y, 'Percentage of nodes in day time zones in function of the number of up nodes classified',
    #                     "Number of nodes", "Percentage")

    # x, y = extract_continent(data)
    # if (x != [] and y != []) and (y != [0, 0, 0, 0, 0, 0]):
    #     pie(y, x, 0, 'Continent distribution')
